# ------------------------------------------------------------------------------
# Scatter plot comparing each feature's correlation in the stent/CABG model with
# its correlation in the test model
# Updated by: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q medium -R "rusage[mem=1000]" bash 08_plot_feature_outcome_cor.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fix the RNG state so any random draws later in the script are reproducible
# (presumably mainly the ggrepel label layout -- verify if that matters).
set.seed(seed = 1L)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(here)
library(yaml)
library(tidyverse)
library(glue)
library(ggthemes)
library(Metrics)
library(viridis)
library(ggrepel)

# Paths & project modules ------------------------------------------------------
# Scratch/output directory; the final figure of this script is saved here.
temp <- here("code", "06_physician_boundedness", "01_behavioral_lasso", "temp")

# Project-local helper modules:
#   `a` -- plot aesthetics (discrete palette `a$disc_palette`, font helper)
#   `u` -- general utilities (e.g. `u$safe_left_join()`)
a <- modules::use(here("lib", "aesthetics.R"))
# Presumably registers the bundled Optima font for use in plots -- confirm.
a$get_font("Optima", here("lib", "optima.ttf"))
u <- modules::use(here("lib", "util.R"))

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- read_yaml(here::here("lib", "filepaths.yml"))

# Human-readable display labels for the features.
feature_labs <- here(
  "code", "03_analysis", "07_limited_bandwidth_models", "corr_var_labels.csv"
) %>%
  read_csv()

# Per-feature correlations (cor_stent_or_cabg, cor_test, ...) joined to their
# labels, with `group` recoded into a display factor for the legend.
# NOTE(review): this input path is relative to the working directory, unlike
# the here()-based paths elsewhere -- confirm where corr_vars_k49.csv lives.
gg_tb <- read_csv("corr_vars_k49.csv") %>%
  u$safe_left_join(feature_labs) %>%
  mutate(
    `Feature Type` = factor(
      group,
      levels = c("symptom", "dem", "other"),
      labels = c("Representative Symptoms", "Demographics", "All Others")
    )
  )

# Label selection --------------------------------------------------------------
# Keep a point's label only if its abs_residual is in the top quartile, or if
# the feature is one we always want called out; every other point gets an
# empty label so ggrepel draws nothing for it.
# na.rm = TRUE: without it a single NA in abs_residual makes quantile() error.
high_resid <- quantile(gg_tb$abs_residual, 0.75, na.rm = TRUE)
gg_tb <- gg_tb %>%
  mutate(
    feature_lab = case_when(
      abs_residual >= high_resid ~ feature_lab,
      feature %in% c("dem_agi_under_25K") ~ feature_lab,
      TRUE ~ ""
    )
  )

# Plot -------------------------------------------------------------------------
message("Plotting...")
# NOTE(review): the R^2 printed on the figure is hard-coded, not computed from
# gg_tb; if the input CSV changes, this value must be updated by hand.
r_squared_label <- 0.433
corr_txt <- bquote(R^2 ~ ":" ~ .(r_squared_label))

gg <- ggplot(
  gg_tb,
  aes(x = cor_stent_or_cabg, y = cor_test, label = feature_lab)
) +
  labs(
    x = "Correlation with Risk",
    y = "Correlation with Test"
  ) +
  theme_bw() +
  # Dotted reference lines through zero on both axes.
  geom_vline(xintercept = 0, linetype = "dotted") +
  geom_hline(yintercept = 0, linetype = "dotted") +
  # Linear fit across all features. `formula` is spelled out to match the
  # default for method = "lm" and silence ggplot2's "using formula" message;
  # no trailing comma, which older ggplot2 versions reject as an empty arg.
  geom_smooth(
    method = "lm", formula = y ~ x,
    color = a$disc_palette[1], fill = a$disc_palette[1],
    alpha = 0.25
  ) +
  annotate(
    "text",
    x = -0.075, y = 0.2, label = corr_txt, family = "Optima",
    fontface = "bold", size = 18
  ) +
  geom_label_repel(size = 14, color = "black", family = "Optima") +
  geom_point(size = 3, aes(shape = `Feature Type`, color = `Feature Type`)) +
  theme(
    legend.position = "bottom",
    text = element_text(family = "Optima", size = 40)
  ) +
  scale_color_manual(values = a$disc_palette[c(3, 2, 1)]) +
  # NOTE(review): no layer maps `fill` to a variable (geom_smooth uses a fixed
  # fill), so this scale currently has no visible effect.
  scale_fill_manual(values = a$disc_palette[c(3, 2, 1)])


# Save -------------------------------------------------------------------------
# Pass the plot explicitly: ggsave()'s default is last_plot(), which silently
# saves whichever ggplot was built most recently rather than `gg`.
ggsave(
  file.path(temp, "LASSO_SCATTER_BY_STENT_OR_CABG_010_DAY.png"),
  plot = gg,
  width = 10, height = 7, units = "in"
)

message("Done.")
